import numpy as np
import pandas as pd
import copy
from statistics import mean

# Column headers of the one-row summary table assembled in save_stats().
N_STEPS_STR = (
    "Mean number of steps among succeeded optimisations"
)
N_E_CALC_STR = (
    "Mean number of energy calculations among succeeded optimisations"
)
FAILURE_RATE_STR = "Failure rate (%)"
TIME_PER_STR_STR = "Average time per structure (sec)"
AVE_FINAL_E_STR = "Average final energy (eV/atom)"

# NOTE(review): not referenced in the visible code; presumably a plotting
# font size used elsewhere -- confirm before removing.
FONT_SIZE = 18

# Order of the columns in the summary table.
stat_keys = [
    N_STEPS_STR,
    N_E_CALC_STR,
    FAILURE_RATE_STR,
    TIME_PER_STR_STR,
    AVE_FINAL_E_STR,
]

def append_df_to_dict(dict, df, i):
    """Append entry ``i`` of every matching column of ``df`` to ``dict``.

    Args:
        dict: Mapping of column name to a list; each list is extended in
            place by one element.
        df: Column-indexable table (``df[name][i]`` must work) holding at
            least the columns named in ``dict``.
        i: Index of the row to copy.
    """
    for column_name, collected in dict.items():
        collected.append(df[column_name][i])


def mean_std_str(values, round_v=6):
    """Format ``values`` as the string ``"<mean>(<spread>)"``.

    The spread is ``np.std(values) / sqrt(len(values))``, i.e. the standard
    deviation divided by the square root of the sample count.

    Args:
        values: Non-empty sequence of numbers.
        round_v: Number of decimals both figures are rounded to.  With ``0``
            both figures are additionally converted to ``int`` so that no
            decimal point is printed.

    Returns:
        The formatted string, e.g. ``"2.5(0.559017)"``.
    """
    centre = round(mean(values), round_v)
    spread = round(np.std(values) / np.sqrt(len(values)), round_v)

    if round_v == 0:
        centre, spread = int(centre), int(spread)

    return f"{centre!s}({spread!s})"


def save_stats(df, output_dir, fmax_threshold=0.05):
    """Summarise a concatenated optimisation log and write ``stat.csv``.

    ``df`` must provide the columns ``Step``, ``Time``, ``Energy`` and
    ``Fmax``; ``E_step`` is optional and falls back to ``Step``.  A row whose
    ``Step`` is 0 starts a new trajectory; the row right before the next such
    row (or the last row of the table) is the final step of the current
    trajectory.  A trajectory counts as succeeded when the ``Fmax`` of its
    final step is below ``fmax_threshold`` and as failed otherwise.

    The resulting one-row table (columns listed in ``stat_keys``) is printed
    and written to ``<output_dir>/stat.csv``.  Every failed trajectory is
    also reported on stdout.

    Args:
        df: ``pandas.DataFrame`` with the columns described above.  Values
            may be strings; they are cast to numbers here.
        output_dir: Existing directory that receives ``stat.csv``.
        fmax_threshold: Convergence limit applied to the final ``Fmax`` of
            each trajectory.

    Returns:
        None.
    """
    # Drop rows whose Fmax cell is the literal text "Fmax" (presumably header
    # lines repeated when several logs were concatenated).  The explicit
    # copy makes the column assignment below act on an independent frame
    # instead of a slice of the caller's data.
    df = df[df["Fmax"] != "Fmax"].copy()

    if "E_step" not in df:
        df["E_step"] = df["Step"]

    df = df.astype({"Fmax": float, "Time": float, "Energy": float, "Step": int, "E_step": int})
    df = df.to_dict(orient="list")

    n_rows = len(df["Fmax"])

    # Final-step values of the succeeded trajectories.
    final_steps = []
    final_e_steps = []
    final_energies = []

    # Final times of all finished trajectories, succeeded or not.
    total_time = 0
    total_time_list = []
    fails_num = 0

    # NOTE(review): as before, row 0 is never inspected and a row with
    # Step == 0 is never treated as a final step, so a trajectory that
    # consists of a single step-0 row is not counted at all.
    for i in range(1, n_rows):

        if df["Step"][i] == 0:
            continue

        is_last = i == n_rows - 1 or df["Step"][i + 1] == 0
        if not is_last:
            continue

        # A negative time means the run passed midnight, so one day
        # (86400 seconds) is added back.
        final_time = float(df["Time"][i])
        if final_time < 0:
            final_time += 86400

        total_time += final_time
        total_time_list.append(final_time)

        if df["Fmax"][i] >= fmax_threshold:
            print(f"Step {df['Step'][i]} row {i} has Fmax > {fmax_threshold}: {df['Fmax'][i]} and energy {df['Energy'][i]}")
            fails_num += 1
        else:
            final_steps.append(df["Step"][i])
            final_e_steps.append(df["E_step"][i])
            final_energies.append(df["Energy"][i])

    stat = {k: [] for k in stat_keys}
    total_opt_n = len(final_steps) + fails_num

    if final_steps:

        stat[N_STEPS_STR].append(mean_std_str(final_steps, round_v=0))
        stat[N_E_CALC_STR].append(mean_std_str(final_e_steps, round_v=0))
        stat[FAILURE_RATE_STR].append(round(fails_num / total_opt_n * 100, 2))
        stat[TIME_PER_STR_STR].append(mean_std_str(total_time_list, round_v=0))
        stat[AVE_FINAL_E_STR].append(mean_std_str(final_energies, 6))

    else:

        # Nothing succeeded.  When no trajectory finished at all there is
        # nothing to average either, so report 0 instead of dividing by zero.
        average_time = round(total_time / total_opt_n, 0) if total_opt_n else 0.0

        stat[N_STEPS_STR].append("0")
        stat[N_E_CALC_STR].append("0")
        stat[FAILURE_RATE_STR].append(100)
        stat[TIME_PER_STR_STR].append(average_time)
        stat[AVE_FINAL_E_STR].append("0")

    print(stat)

    pd.DataFrame(stat).to_csv(f"{output_dir}/stat.csv", index=False)
